import argparse
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, repeat
import itertools


def chunk_state_triton(B, x, dt, dA_cumsum):
    """Compute chunk states with mamba_ssm's Triton ``_chunk_state_fwd``.

    Takes the same four tensors as ``ref_program`` and forwards them
    unchanged, requesting ``states_in_fp32=False``. The import is local to this
    function, so importing this module does not itself require mamba_ssm.
    """
    from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd as triton_fwd
    states = triton_fwd(B, x, dt, dA_cumsum, states_in_fp32=False)
    return states


def ref_program(B, x, dt, dA_cumsum):
    """
    Reference PyTorch implementation of the chunked state computation.

    With L = chunk_size and g(h) = h // (nheads // ngroups) the group whose B
    head h reads, the result is

        states[b, c, h, p, n] = sum_l exp(dA_cumsum[b, h, c, L - 1] - dA_cumsum[b, h, c, l])
                                      * dt[b, h, c, l]
                                      * x[b, c * L + l, h, p]
                                      * B[b, c * L + l, g(h), n]

    Positions with c * L + l >= seqlen contribute nothing: the sequence axis is
    zero-padded up to nchunks * chunk_size before chunking.

    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate), in x.dtype
        (B, dt and the decay factors are cast to x.dtype before the einsum).
    """
    # Check constraints.
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen <= nchunks * chunk_size
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    ngroups = B.shape[2]
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)

    # Expand groups to heads: each group is shared by nheads // ngroups
    # consecutive heads (equivalent to einops "b l g d -> b l (g h) d").
    B = B.repeat_interleave(nheads // ngroups, dim=2)

    padded_len = nchunks * chunk_size
    if seqlen < padded_len:
        # F.pad pairs start at the last dim: (dim -1, dim -2, dim 1) -> pad the
        # end of the sequence axis with zeros so it splits evenly into chunks.
        x = F.pad(x, (0, 0, 0, 0, 0, padded_len - seqlen))
        B = F.pad(B, (0, 0, 0, 0, 0, padded_len - seqlen))

    # Split the (padded) sequence axis into (nchunks, chunk_size).
    x = x.reshape(batch, nchunks, chunk_size, nheads, headdim)
    B = B.reshape(batch, nchunks, chunk_size, nheads, dstate)

    # Decay from each position to the last position of its chunk.
    decay_states = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum)
    return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype),
                        dt.to(x.dtype), x)


def get_configs():
    """Return the autotuning search space as a list of keyword-argument dicts.

    The list is the Cartesian product of the candidate tile sizes and stage
    counts, ordered with ``num_stages`` varying fastest and ``block_M`` slowest.
    """
    return [
        {"block_M": m, "block_N": n, "block_K": k, "num_stages": s}
        for m in (64, 128)
        for n in (32, 64, 128)
        for k in (32, 64)
        for s in (1, 2, 3, 4, 5)
    ]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[4])
def chunk_state_fwd(batch,
                    seqlen,
                    chunk_size,
                    ngroups,
                    nheads,
                    headdim,
                    dstate,
                    block_M=64,
                    block_N=64,
                    block_K=64,
                    num_stages=2,
                    threads=128):
    """Build the TileLang kernel that computes per-chunk states.

    For every (batch, chunk, head) the kernel accumulates, over the positions
    l of the chunk, ``x[l]^T @ B[l]`` with each x row scaled by
    ``exp(dA_cumsum[last] - dA_cumsum[l]) * dt[l]``. The ``__main__`` block
    checks this against ``ref_program``.

    Args:
        batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate: problem
            sizes; nchunks is derived as ceildiv(seqlen, chunk_size).
        block_M: output tile size along headdim.
        block_N: output tile size along dstate.
        block_K: tile size along the intra-chunk position axis (the reduction).
        num_stages: pipeline depth passed to ``T.Pipelined``.
        threads: threads per block passed to ``T.Kernel``.

    Returns:
        The ``T.prim_func`` ``main``. ``out_idx=[4]`` presumably designates the
        fifth parameter (``Output``) as the kernel's result; the profiler in
        ``__main__`` compares it with ``ref_program``, which takes the other
        four tensors. NOTE(review): confirm against tilelang.jit docs.
    """
    dtype = "float16"
    accum_dtype = "float"
    nchunks = T.ceildiv(seqlen, chunk_size)
    # log2(e): exp(a - b) is evaluated below as exp2(a * p - b * p).
    p = 1.44269504

    @T.prim_func
    def main(B: T.Tensor((batch, seqlen, ngroups, dstate), dtype), x: T.Tensor(
        (batch, seqlen, nheads, headdim), dtype), dt: T.Tensor(
            (batch, nheads, nchunks, chunk_size), dtype), dA_cumsum: T.Tensor(
                (batch, nheads, nchunks, chunk_size), dtype), Output: T.Tensor(
                    (batch, nchunks, nheads, headdim, dstate), dtype)):
        # Grid: bz -> head; bx -> flattened (headdim tile, dstate tile);
        # by -> flattened (chunk, batch) with batch varying fastest.
        with T.Kernel(
                nheads,
                T.ceildiv(headdim, block_M) * T.ceildiv(dstate, block_N),
                batch * nchunks,
                threads=threads) as (bz, bx, by):
            # x tile as loaded (positions x headdim) and its scaled transpose.
            x_shared = T.alloc_shared((block_K, block_M), dtype)
            x_local = T.alloc_fragment((block_K, block_M), dtype)
            xt_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            dt_shared = T.alloc_shared((block_K), dtype)
            dA_cumsum_shared = T.alloc_shared((block_K), dtype)
            # Float accumulator for one (headdim tile x dstate tile) of Output.
            acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
            # float16 tile that acc_o is copied into before the store to Output.
            acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
            # Per-position weight exp(dA_last - dA[l]) * dt[l].
            scale = T.alloc_fragment((block_K), accum_dtype)
            dA_cs_last = T.alloc_fragment((1), accum_dtype)
            dA_cumsum_local = T.alloc_fragment((block_K), accum_dtype)
            dt_local = T.alloc_fragment((block_K), accum_dtype)

            # Number of block_K steps needed to cover one chunk.
            loop_range = T.ceildiv(chunk_size, block_K)

            batch_idx = by % batch
            chunk_idx = by // batch
            m_idx = bx // T.ceildiv(dstate, block_N)
            n_idx = bx % T.ceildiv(dstate, block_N)

            T.annotate_layout({
                x_shared: tilelang.layout.make_swizzled_layout(x_shared),
                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared)
            })

            # dA_cumsum at the last position of this chunk (constant over k).
            dA_cs_last[0] = dA_cumsum[batch_idx, bz, chunk_idx, chunk_size - 1]
            T.clear(acc_o)
            # NOTE(review): if seqlen is not a multiple of chunk_size, or
            # chunk_size not a multiple of block_K, the slices below reach past
            # the tensor extents; confirm T.copy bounds-checks these accesses.
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(
                    x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
                      (k + 1) * block_K, bz, m_idx * block_M:(m_idx + 1) * block_M], x_shared)
                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
                       dA_cumsum_shared)
                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
                T.copy(dA_cumsum_shared, dA_cumsum_local)
                T.copy(dt_shared, dt_local)
                for i in T.Parallel(block_K):
                    scale[i] = T.exp2(dA_cs_last[0] * p - dA_cumsum_local[i] * p) * dt_local[i]
                T.copy(x_shared, x_local)
                # Transpose x to (headdim, positions) while applying the weight.
                for i, j in T.Parallel(block_M, block_K):
                    xt_local[i, j] = x_local[j, i] * scale[j]
                # Head bz reads B from group bz // (nheads // ngroups).
                T.copy(
                    B[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
                      (k + 1) * block_K, bz // (nheads // ngroups),
                      n_idx * block_N:(n_idx + 1) * block_N], B_shared)
                # acc_o (block_M x block_N) += xt_local (block_M x block_K) @ B_shared (block_K x block_N)
                T.gemm(xt_local, B_shared, acc_o)
            T.copy(acc_o, acc_o_shared)
            T.copy(
                acc_o_shared,
                Output[batch_idx, chunk_idx, bz, m_idx * block_M:(m_idx + 1) * block_M,
                       n_idx * block_N:(n_idx + 1) * block_N])

    return main


if __name__ == "__main__":
    # Integer problem-size flags: (flag, default, help text).
    int_options = (
        ('--batch', 8, 'batch size'),
        ('--heads', 80, 'heads'),
        ('--groups', 1, 'groups'),
        ('--seq_len', 4096, 'sequence length'),
        ('--chunk_size', 256, 'chunk size'),
        ('--dim', 64, 'dim'),
        ('--dstate', 128, 'dstate'),
    )
    parser = argparse.ArgumentParser()
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()

    batch = args.batch
    heads = args.heads
    groups = args.groups
    seq_len = args.seq_len
    chunk_size = args.chunk_size
    dim = args.dim
    dstate = args.dstate
    # FLOP count used for throughput: 2 (multiply + add) per
    # batch x token x head x headdim x dstate element.
    total_flops = 2 * batch * seq_len * heads * dim * dstate

    def tflops(latency_ms):
        # FLOPs per millisecond scaled by 1e-9 gives TFLOP/s.
        return total_flops / latency_ms * 1e-9

    if args.tune:
        # Autotune over get_configs() and report the best configuration.
        best_result = chunk_state_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
        best_latency = best_result.latency
        best_config = best_result.config
        ref_latency = best_result.ref_latency
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {tflops(best_latency)}")
        print(f"Best config: {best_config}")
    else:
        # Fixed configuration: validate against the reference, then benchmark both.
        kernel = chunk_state_fwd(
            batch, seq_len, chunk_size, groups, heads, dim, dstate,
            block_M=64, block_N=128, block_K=64, num_stages=4, threads=128)
        profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        ref_ms = profiler.do_bench(ref_program, warmup=500)
        print(f"Ref: {ref_ms:.2f} ms")
        print(f"Ref: {tflops(ref_ms):.2f} TFlops")
        tl_ms = profiler.do_bench(warmup=500)
        print(f"Tile-lang: {tl_ms:.2f} ms")
        print(f"Tile-lang: {tflops(tl_ms):.2f} TFlops")
